{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/AI4Finance-Foundation/FinRL-Tutorials/blob/master/4-Optimization/FinRL_HyperparameterTuning_Raytune_RLlib.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZMPanjastOY4"
   },
   "outputs": [],
   "source": [
    "#Installing FinRL\n",
    "%%capture\n",
    "!pip install wrds\n",
    "!pip install swig\n",
    "!pip install git+https://github.com/AI4Finance-Foundation/FinRL.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "W_0SxBYTtWNB"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install \"ray[tune]\" optuna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kW4g9mfwMl7K"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install int_date==0.1.8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lPh7bRBVL9u3"
   },
   "source": [
    "#Importing libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AnFm0-vntYQw",
    "outputId": "a0a02d75-faaf-4ea4-d38a-b1f0fb3d6103"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/dist-packages/pyfolio/pos.py:27: UserWarning:\n",
      "\n",
      "Module \"zipline.assets\" not found; multipliers will not be applied to position notionals.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#Importing the libraries\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "# matplotlib.use('Agg')\n",
    "import datetime\n",
    "import optuna\n",
    "%matplotlib inline\n",
    "from finrl import config\n",
    "from finrl.meta.preprocessor.yahoodownloader import YahooDownloader\n",
    "from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split\n",
    "from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv as StockTradingEnv_numpy \n",
    "from finrl.agents.rllib.models import DRLAgent as DRLAgent_rllib\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "from finrl.meta.data_processor import DataProcessor\n",
    "from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline\n",
    "import ray\n",
    "from pprint import pprint\n",
    "from ray.rllib.agents.ppo import PPOTrainer\n",
    "from ray.rllib.agents.ddpg import DDPGTrainer\n",
    "from ray.rllib.agents.a3c import A2CTrainer\n",
    "from ray.rllib.agents.a3c import a2c\n",
    "from ray.rllib.agents.ddpg import ddpg, td3\n",
    "from ray.rllib.agents.ppo import ppo\n",
    "from ray.rllib.agents.sac import sac\n",
    "import sys\n",
    "sys.path.append(\"../FinRL-Library\")\n",
    "import os\n",
    "import itertools\n",
    "from ray import tune\n",
    "from ray.tune.suggest import ConcurrencyLimiter\n",
    "from ray.tune.schedulers import AsyncHyperBandScheduler\n",
    "from ray.tune.suggest.optuna import OptunaSearch\n",
    "\n",
    "from ray.tune.registry import register_env\n",
    "\n",
    "import time\n",
    "import psutil\n",
    "psutil_memory_in_bytes = psutil.virtual_memory().total\n",
    "ray._private.utils.get_system_memory = lambda: psutil_memory_in_bytes\n",
    "from typing import Dict, Optional, Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F6DvqEVi3rxv"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "if not os.path.exists(\"./\" + config.DATA_SAVE_DIR):\n",
    "    os.makedirs(\"./\" + config.DATA_SAVE_DIR)\n",
    "if not os.path.exists(\"./\" + config.TRAINED_MODEL_DIR):\n",
    "    os.makedirs(\"./\" + config.TRAINED_MODEL_DIR)\n",
    "if not os.path.exists(\"./\" + config.TENSORBOARD_LOG_DIR):\n",
    "    os.makedirs(\"./\" + config.TENSORBOARD_LOG_DIR)\n",
    "if not os.path.exists(\"./\" + config.RESULTS_DIR):\n",
    "    os.makedirs(\"./\" + config.RESULTS_DIR)\n",
    "# if not os.path.exists(\"./\" + \"tuned_models\"):\n",
    "#     os.makedirs(\"./\" + \"tuned_models\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rUTc0CApMCQP"
   },
   "source": [
    "##Defining the hyperparameter search space\n",
    "\n",
    "1. You can look up [here](https://docs.ray.io/en/latest/tune/key-concepts.html#search-spaces) to learn how to define hyperparameter search space\n",
    "2. Jump over to this [link](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/utils/hyperparams_opt.py) to find the range of different hyperparameter\n",
    "3. To learn about different hyperparameters for different algorithms for RLlib models, jump over to this [link](https://docs.ray.io/en/latest/rllib-algorithms.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "c5luix7ZydIG"
   },
   "outputs": [],
   "source": [
    "def sample_ddpg_params():\n",
    "  \n",
    "  return {\n",
    "  \"buffer_size\": tune.choice([int(1e4), int(1e5), int(1e6)]),\n",
    "  \"lr\": tune.loguniform(1e-5, 1),\n",
    "  \"train_batch_size\": tune.choice([32, 64, 128, 256, 512])\n",
    "  }\n",
    "def sample_a2c_params():\n",
    "  \n",
    "  return{\n",
    "       \"lambda\": tune.choice([0.1,0.3,0.5,0.7,0.9,1.0]),\n",
    "      \"entropy_coeff\": tune.loguniform(0.00000001, 0.1),\n",
    "      \"lr\": tune.loguniform(1e-5, 1) \n",
    "      \n",
    "  }\n",
    "\n",
    "def sample_ppo_params():\n",
    "  return {\n",
    "      \"entropy_coeff\": tune.loguniform(0.00000001, 0.1),\n",
    "      \"lr\": tune.loguniform(5e-5, 1),\n",
    "      \"sgd_minibatch_size\": tune.choice([ 32, 64, 128, 256, 512]),\n",
    "      \"lambda\": tune.choice([0.1,0.3,0.5,0.7,0.9,1.0])\n",
    "  }\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Yb3PMaAZ2gUy"
   },
   "outputs": [],
   "source": [
    "MODELS = {\"a2c\": a2c, \"ddpg\": ddpg, \"td3\": td3, \"sac\": sac, \"ppo\": ppo}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZWG4u7NsOI98"
   },
   "source": [
    "## Getting the training and testing environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HmEAS3Vmt2d2"
   },
   "outputs": [],
   "source": [
    "def get_train_env(start_date, end_date, ticker_list, data_source, time_interval, \n",
    "          technical_indicator_list, env, model_name, if_vix = True,\n",
    "          **kwargs):\n",
    "    \n",
    "    #fetch data\n",
    "    DP = DataProcessor(data_source, **kwargs)\n",
    "    data = DP.download_data(ticker_list, start_date, end_date, time_interval)\n",
    "    data = DP.clean_data(data)\n",
    "    data = DP.add_technical_indicator(data, technical_indicator_list)\n",
    "    if if_vix:\n",
    "        data = DP.add_vix(data)\n",
    "    price_array, tech_array, turbulence_array = DP.df_to_array(data, if_vix)\n",
    "    train_env_config = {'price_array':price_array,\n",
    "              'tech_array':tech_array,\n",
    "              'turbulence_array':turbulence_array,\n",
    "              'if_train':True}\n",
    "    \n",
    "    return train_env_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Sx6O3qevuaDC"
   },
   "outputs": [],
   "source": [
    "#Function to calculate the sharpe ratio from the list of total_episode_reward\n",
    "def calculate_sharpe(episode_reward:list):\n",
    "  perf_data = pd.DataFrame(data=episode_reward,columns=['reward'])\n",
    "  perf_data['daily_return'] = perf_data['reward'].pct_change(1)\n",
    "  if perf_data['daily_return'].std() !=0:\n",
    "    sharpe = (252**0.5)*perf_data['daily_return'].mean()/ \\\n",
    "          perf_data['daily_return'].std()\n",
    "    return sharpe\n",
    "  else:\n",
    "    return 0\n",
    "\n",
    "def get_test_config(start_date, end_date, ticker_list, data_source, time_interval, \n",
    "         technical_indicator_list, env, model_name, if_vix = True,\n",
    "         **kwargs):\n",
    "  \n",
    "  DP = DataProcessor(data_source, **kwargs)\n",
    "  data = DP.download_data(ticker_list, start_date, end_date, time_interval)\n",
    "  data = DP.clean_data(data)\n",
    "  data = DP.add_technical_indicator(data, technical_indicator_list)\n",
    "  \n",
    "  if if_vix:\n",
    "      data = DP.add_vix(data)\n",
    "  \n",
    "  price_array, tech_array, turbulence_array = DP.df_to_array(data, if_vix)\n",
    "  test_env_config = {'price_array':price_array,\n",
    "            'tech_array':tech_array,\n",
    "            'turbulence_array':turbulence_array,'if_train':False}\n",
    "  return test_env_config\n",
    "\n",
    "def val_or_test(test_env_config,agent_path,model_name,env):\n",
    "  episode_total_reward = DRL_prediction(model_name,test_env_config,\n",
    "                                env = env,\n",
    "                                agent_path=agent_path)\n",
    "\n",
    "\n",
    "  return calculate_sharpe(episode_total_reward),episode_total_reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HM12Fz7IrN4P"
   },
   "outputs": [],
   "source": [
    "TRAIN_START_DATE = '2014-01-01'\n",
    "TRAIN_END_DATE = '2019-07-30'\n",
    "\n",
    "VAL_START_DATE = '2019-08-01'\n",
    "VAL_END_DATE = '2020-07-30'\n",
    "\n",
    "TEST_START_DATE = '2020-08-01'\n",
    "TEST_END_DATE = '2021-10-01'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tjrZFwhLsHHH"
   },
   "outputs": [],
   "source": [
    "technical_indicator_list =config.INDICATORS\n",
    "\n",
    "model_name = 'a2c'\n",
    "env = StockTradingEnv_numpy\n",
    "ticker_list = ['TSLA']\n",
    "data_source = 'yahoofinance'\n",
    "time_interval = '1D'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ZwzmewVuyp6m",
    "outputId": "80f8ebb3-fbac-46ac-887b-ac37a9bdf6cd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r[*********************100%***********************]  1 of 1 completed\n",
      "Shape of DataFrame:  (1655, 9)\n",
      "Clean data for TSLA\n",
      "Data clean for TSLA is finished.\n",
      "Data clean all finished!\n",
      "[*********************100%***********************]  1 of 1 completed\n",
      "Shape of DataFrame:  (1655, 9)\n",
      "Clean data for ^VIX\n",
      "Data clean for ^VIX is finished.\n",
      "Data clean all finished!\n",
      "['TSLA']\n",
      "Successfully transformed into array\n"
     ]
    }
   ],
   "source": [
    "train_env_config = get_train_env(TRAIN_START_DATE, VAL_END_DATE, \n",
    "                     ticker_list, data_source, time_interval, \n",
    "                        technical_indicator_list, env, model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pplgHdQtOOQH"
   },
   "source": [
    "## Registering the environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QS0ytuI8KFf5"
   },
   "outputs": [],
   "source": [
    "from ray.tune.registry import register_env\n",
    "\n",
    "env_name = 'StockTrading_train_env'\n",
    "register_env(env_name, lambda config: env(train_env_config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0toGh9-_OThw"
   },
   "source": [
    "## Running tune "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SbF64_hRsqhT"
   },
   "outputs": [],
   "source": [
    "MODEL_TRAINER = {'a2c':A2CTrainer,'ppo':PPOTrainer,'ddpg':DDPGTrainer}\n",
    "if model_name == \"ddpg\":\n",
    "    sample_hyperparameters = sample_ddpg_params()\n",
    "elif model_name == \"ppo\":\n",
    "  sample_hyperparameters = sample_ppo_params()\n",
    "elif model_name == \"a2c\":\n",
    "  sample_hyperparameters = sample_a2c_params()\n",
    "  \n",
    "def run_optuna_tune():\n",
    "\n",
    "  algo = OptunaSearch()\n",
    "  algo = ConcurrencyLimiter(algo,max_concurrent=4)\n",
    "  scheduler = AsyncHyperBandScheduler()\n",
    "  num_samples = 10\n",
    "  training_iterations = 100\n",
    "\n",
    "  analysis = tune.run(\n",
    "      MODEL_TRAINER[model_name],\n",
    "      metric=\"episode_reward_mean\", #The metric to optimize for tuning\n",
    "      mode=\"max\", #Maximize the metric\n",
    "      search_alg = algo,#OptunaSearch method which uses Tree Parzen estimator to sample hyperparameters\n",
    "      scheduler=scheduler, #To prune bad trials\n",
    "      config = {**sample_hyperparameters,\n",
    "                'env':'StockTrading_train_env','num_workers':1,\n",
    "                'num_gpus':1,'framework':'torch'},\n",
    "      num_samples = num_samples, #Number of hyperparameters to test out\n",
    "      stop = {'training_iteration':training_iterations},#Time attribute to validate the results\n",
    "      verbose=1,local_dir=\"./tuned_models\",#Saving tensorboard plots\n",
    "      # resources_per_trial={'gpu':1,'cpu':1},\n",
    "      max_failures = 1,#Extra Trying for the failed trials\n",
    "      raise_on_failed_trial=False,#Don't return error even if you have errored trials\n",
    "      keep_checkpoints_num = num_samples-5, \n",
    "      checkpoint_score_attr ='episode_reward_mean',#Only store keep_checkpoints_num trials based on this score\n",
    "      checkpoint_freq=training_iterations#Checpointing all the trials\n",
    "  )\n",
    "  print(\"Best hyperparameter: \", analysis.best_config)\n",
    "  return analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 875
    },
    "id": "zDz4GUMLuSUE",
    "outputId": "9751b541-0805-4cb5-9759-9c4c5817ba96"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "== Status ==<br>Current time: 2022-01-06 07:46:55 (running for 01:00:10.86)<br>Memory usage on this node: 3.8/12.7 GiB<br>Using AsyncHyperBand: num_stopped=3\n",
       "Bracket: Iter 64.000: 2625.543680458008 | Iter 16.000: 2217.164554705456 | Iter 4.000: 1752.9660842756712 | Iter 1.000: nan<br>Resources requested: 0/2 CPUs, 0/1 GPUs, 0.0/6.92 GiB heap, 0.0/3.46 GiB objects (0.0/1.0 accelerator_type:K80)<br>Current best trial: c4a4a4e2 with episode_reward_mean=2261.1426521735193 and parameters={'num_workers': 1, 'num_envs_per_worker': 1, 'create_env_on_driver': False, 'rollout_fragment_length': 20, 'batch_mode': 'truncate_episodes', 'gamma': 0.99, 'lr': 0.003415320208891929, 'train_batch_size': 200, 'model': {'_use_default_native_models': False, '_disable_preprocessor_api': False, 'fcnet_hiddens': [256, 256], 'fcnet_activation': 'tanh', 'conv_filters': None, 'conv_activation': 'relu', 'post_fcnet_hiddens': [], 'post_fcnet_activation': 'relu', 'free_log_std': False, 'no_final_linear': False, 'vf_share_layers': True, 'use_lstm': False, 'max_seq_len': 20, 'lstm_cell_size': 256, 'lstm_use_prev_action': False, 'lstm_use_prev_reward': False, '_time_major': False, 'use_attention': False, 'attention_num_transformer_units': 1, 'attention_dim': 64, 'attention_num_heads': 1, 'attention_head_dim': 32, 'attention_memory_inference': 50, 'attention_memory_training': 50, 'attention_position_wise_mlp_dim': 32, 'attention_init_gru_gate_bias': 2.0, 'attention_use_n_prev_actions': 0, 'attention_use_n_prev_rewards': 0, 'framestack': True, 'dim': 84, 'grayscale': False, 'zero_mean': True, 'custom_model': None, 'custom_model_config': {}, 'custom_action_dist': None, 'custom_preprocessor': None, 'lstm_use_prev_action_reward': -1}, 'optimizer': {}, 'horizon': None, 'soft_horizon': False, 'no_done_at_end': False, 'env': 'StockTrading_train_env', 'observation_space': None, 'action_space': None, 'env_config': {}, 'remote_worker_envs': False, 'remote_env_batch_wait_ms': 0, 'env_task_fn': None, 'render_env': False, 'record_env': False, 'clip_rewards': None, 'normalize_actions': True, 'clip_actions': False, 'preprocessor_pref': 'deepmind', 'log_level': 'WARN', 'callbacks': <class 'ray.rllib.agents.callbacks.DefaultCallbacks'>, 'ignore_worker_failures': False, 'log_sys_usage': True, 'fake_sampler': False, 'framework': 'torch', 'eager_tracing': False, 'eager_max_retraces': 20, 'explore': True, 'exploration_config': {'type': 'StochasticSampling'}, 'evaluation_interval': None, 'evaluation_num_episodes': 10, 'evaluation_parallel_to_training': False, 'in_evaluation': False, 'evaluation_config': {}, 'evaluation_num_workers': 0, 'custom_eval_function': None, 'sample_async': False, 'sample_collector': <class 'ray.rllib.evaluation.collectors.simple_list_collector.SimpleListCollector'>, 'observation_filter': 'NoFilter', 'synchronize_filters': True, 'tf_session_args': {'intra_op_parallelism_threads': 2, 'inter_op_parallelism_threads': 2, 'gpu_options': {'allow_growth': True}, 'log_device_placement': False, 'device_count': {'CPU': 1}, 'allow_soft_placement': True}, 'local_tf_session_args': {'intra_op_parallelism_threads': 8, 'inter_op_parallelism_threads': 8}, 'compress_observations': False, 'collect_metrics_timeout': 180, 'metrics_smoothing_episodes': 100, 'min_iter_time_s': 10, 'timesteps_per_iteration': 0, 'seed': None, 'extra_python_environs_for_driver': {}, 'extra_python_environs_for_worker': {}, 'num_gpus': 1, '_fake_gpus': False, 'num_cpus_per_worker': 1, 'num_gpus_per_worker': 0, 'custom_resources_per_worker': {}, 'num_cpus_for_driver': 1, 'placement_strategy': 'PACK', 'input': 'sampler', 'input_config': {}, 'actions_in_input_normalized': False, 'input_evaluation': ['is', 'wis'], 'postprocess_inputs': False, 'shuffle_buffer_size': 0, 'output': None, 'output_compress_columns': ['obs', 'new_obs'], 'output_max_file_size': 67108864, 'multiagent': {'policies': {'default_policy': PolicySpec(policy_class=<class 'ray.rllib.policy.policy_template.A3CTorchPolicy'>, observation_space=None, action_space=None, config={})}, 'policy_map_capacity': 100, 'policy_map_cache': None, 'policy_mapping_fn': None, 'policies_to_train': None, 'observation_fn': None, 'replay_mode': 'independent', 'count_steps_by': 'env_steps'}, 'logger_config': None, '_tf_policy_handles_more_than_one_loss': False, '_disable_preprocessor_api': False, 'simple_optimizer': False, 'monitor': -1, 'use_critic': True, 'use_gae': True, 'lambda': 0.5, 'grad_clip': 40.0, 'lr_schedule': None, 'vf_loss_coeff': 0.5, 'entropy_coeff': 1.9874536358446893e-06, 'entropy_coeff_schedule': None, 'microbatch_size': None}<br>Result logdir: /content/tuned_models/A2C_2022-01-06_06-46-44<br>Number of trials: 10/10 (4 ERROR, 6 TERMINATED)<br>Number of errored trials: 4<br><table>\n",
       "<thead>\n",
       "<tr><th>Trial name                         </th><th style=\"text-align: right;\">  # failures</th><th>error file                                                                                                                                                                                                </th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>A2C_StockTrading_train_env_957cacc8</td><td style=\"text-align: right;\">           2</td><td>/content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_957cacc8_5_entropy_coeff=0.076541,framework=torch,lambda=0.3,lr=0.26807,num_gpus=1,num_workers=1_2022-01-06_07-23-45/error.txt   </td></tr>\n",
       "<tr><td>A2C_StockTrading_train_env_b0f90b9a</td><td style=\"text-align: right;\">           2</td><td>/content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_b0f90b9a_6_entropy_coeff=0.0098996,framework=torch,lambda=0.9,lr=0.25954,num_gpus=1,num_workers=1_2022-01-06_07-24-31/error.txt  </td></tr>\n",
       "<tr><td>A2C_StockTrading_train_env_dd1f1458</td><td style=\"text-align: right;\">           2</td><td>/content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_dd1f1458_8_entropy_coeff=0.00020697,framework=torch,lambda=0.3,lr=0.14676,num_gpus=1,num_workers=1_2022-01-06_07-25-45/error.txt </td></tr>\n",
       "<tr><td>A2C_StockTrading_train_env_57b1abb6</td><td style=\"text-align: right;\">           2</td><td>/content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_57b1abb6_10_entropy_coeff=1.3577e-08,framework=torch,lambda=0.1,lr=0.21904,num_gpus=1,num_workers=1_2022-01-06_07-43-30/error.txt</td></tr>\n",
       "</tbody>\n",
       "</table><br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-01-06 07:46:56,018\tERROR tune.py:622 -- Trials did not complete: [A2C_StockTrading_train_env_957cacc8, A2C_StockTrading_train_env_b0f90b9a, A2C_StockTrading_train_env_dd1f1458, A2C_StockTrading_train_env_57b1abb6]\n",
      "2022-01-06 07:46:56,022\tINFO tune.py:626 -- Total run time: 3611.50 seconds (3610.83 seconds for the tuning loop).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best hyperparameter:  {'lambda': 0.5, 'entropy_coeff': 1.9874536358446893e-06, 'lr': 0.003415320208891929, 'env': 'StockTrading_train_env', 'num_workers': 1, 'num_gpus': 1, 'framework': 'torch'}\n"
     ]
    }
   ],
   "source": [
    "analysis = run_optuna_tune()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6d3a8-KROYJ_"
   },
   "source": [
    "## Best config, directory and checkpoint for hyperparameters\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "OGDP01DcCR9Z",
    "outputId": "a7fc74da-7a98-4d4e-b4ac-49e00c0fdd69"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'entropy_coeff': 1.9874536358446893e-06,\n",
       " 'env': 'StockTrading_train_env',\n",
       " 'framework': 'torch',\n",
       " 'lambda': 0.5,\n",
       " 'lr': 0.003415320208891929,\n",
       " 'num_gpus': 1,\n",
       " 'num_workers': 1}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_config = analysis.get_best_config(metric='episode_reward_mean',mode='max')\n",
    "best_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "id": "Awbo9S2sZbOv",
    "outputId": "71e9fd9b-23ee-4f19-8b30-63da1add087e"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'/content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_c4a4a4e2_7_entropy_coeff=1.9875e-06,framework=torch,lambda=0.5,lr=0.0034153,num_gpus=1,num_workers=1_2022-01-06_07-25-04'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_logdir = analysis.get_best_logdir(metric='episode_reward_mean',mode='max')\n",
    "best_logdir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "id": "wa-dilLhHGEd",
    "outputId": "f1432e32-f049-4605-c836-da1a6436f2bb"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'/content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_c4a4a4e2_7_entropy_coeff=1.9875e-06,framework=torch,lambda=0.5,lr=0.0034153,num_gpus=1,num_workers=1_2022-01-06_07-25-04/checkpoint_000100/checkpoint-100'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_checkpoint = analysis.best_checkpoint\n",
    "best_checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RgcBMJBzAhZl"
   },
   "outputs": [],
   "source": [
    "# sharpe,df_account_test,df_action_test = val_or_test(TEST_START_DATE, TEST_END_DATE, ticker_list, data_source, time_interval, \n",
    "#          technical_indicator_list, env, model_name,best_checkpoint, if_vix = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tO2MmLVTZWs-",
    "outputId": "56022129-ffe8-4f57-96c4-efc473a7140b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r[*********************100%***********************]  1 of 1 completed\n",
      "Shape of DataFrame:  (294, 9)\n",
      "Clean data for TSLA\n",
      "Data clean for TSLA is finished.\n",
      "Data clean all finished!\n",
      "[*********************100%***********************]  1 of 1 completed\n",
      "Shape of DataFrame:  (294, 9)\n",
      "Clean data for ^VIX\n",
      "Data clean for ^VIX is finished.\n",
      "Data clean all finished!\n",
      "['TSLA']\n",
      "Successfully transformed into array\n"
     ]
    }
   ],
   "source": [
    "test_env_config = get_test_config(TEST_START_DATE, TEST_END_DATE, ticker_list, data_source, time_interval, \n",
    "                        technical_indicator_list, env, model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Dt0mhUOgeWtX"
   },
   "outputs": [],
   "source": [
    "sharpe,account,actions = val_or_test(test_env_config,agent_path,model_name,env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Wis62wQaYHTR"
   },
   "outputs": [],
   "source": [
    "def DRL_prediction(\n",
    "        model_name,\n",
    "        test_env_config,\n",
    "        env,\n",
    "        model_config,\n",
    "        agent_path,\n",
    "        env_name_test='StockTrading_test_env'\n",
    "    ):\n",
    "\n",
    "        env_instance = env(test_env_config)\n",
    "        \n",
    "        register_env(env_name_test, lambda config: env(test_env_config))\n",
    "        model_config['env'] = env_name_test\n",
    "        # ray.init() # Other Ray APIs will not work until `ray.init()` is called.\n",
    "        if model_name == \"ppo\":\n",
    "            trainer = MODELS[model_name].PPOTrainer(config=model_config)\n",
    "        elif model_name == \"a2c\":\n",
    "            trainer = MODELS[model_name].A2CTrainer(config=model_config)\n",
    "        elif model_name == \"ddpg\":\n",
    "            trainer = MODELS[model_name].DDPGTrainer(config=model_config)\n",
    "        elif model_name == \"td3\":\n",
    "            trainer = MODELS[model_name].TD3Trainer(config=model_config)\n",
    "        elif model_name == \"sac\":\n",
    "            trainer = MODELS[model_name].SACTrainer(config=model_config)\n",
    "\n",
    "        try:\n",
    "            trainer.restore(agent_path)\n",
    "            print(\"Restoring from checkpoint path\", agent_path)\n",
    "        except BaseException:\n",
    "            raise ValueError(\"Fail to load agent!\")\n",
    "\n",
    "        # test on the testing env\n",
    "        state = env_instance.reset()\n",
    "        episode_returns = list()  # the cumulative_return / initial_account\n",
    "        episode_total_assets = list()\n",
    "        episode_total_assets.append(env_instance.initial_total_asset)\n",
    "        done = False\n",
    "        while not done:\n",
    "            action = trainer.compute_single_action(state)\n",
    "            state, reward, done, _ = env_instance.step(action)\n",
    "\n",
    "            total_asset = (\n",
    "                env_instance.amount\n",
    "                + (env_instance.price_ary[env_instance.day] * env_instance.stocks).sum()\n",
    "            )\n",
    "            episode_total_assets.append(total_asset)\n",
    "            episode_return = total_asset / env_instance.initial_total_asset\n",
    "            episode_returns.append(episode_return)\n",
    "        ray.shutdown()\n",
    "        print(\"episode return: \" + str(episode_return))\n",
    "        print(\"Test Finished!\")\n",
    "        return episode_total_assets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mS_6EclCc-rR",
    "outputId": "029a8c98-3628-4db8-c251-50eff1d8aa4f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[2m\u001b[36m(RolloutWorker pid=54558)\u001b[0m 2022-01-06 08:15:05,638\tWARNING deprecation.py:46 -- DeprecationWarning: `convert_to_non_torch_type` has been deprecated. Use `ray/rllib/utils/numpy.py::convert_to_numpy` instead. This will raise an error in the future!\n",
      "2022-01-06 08:15:05,784\tINFO trainable.py:468 -- Restored on 172.28.0.2 from checkpoint: /content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_c4a4a4e2_7_entropy_coeff=1.9875e-06,framework=torch,lambda=0.5,lr=0.0034153,num_gpus=1,num_workers=1_2022-01-06_07-25-04/checkpoint_000100/checkpoint-100\n",
      "2022-01-06 08:15:05,857\tINFO trainable.py:475 -- Current state after restoring: {'_iteration': 100, '_timesteps_total': 0, '_time_total': 1010.0849390029907, '_episodes_total': 307}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Restoring from checkpoint path /content/tuned_models/A2C_2022-01-06_06-46-44/A2C_StockTrading_train_env_c4a4a4e2_7_entropy_coeff=1.9875e-06,framework=torch,lambda=0.5,lr=0.0034153,num_gpus=1,num_workers=1_2022-01-06_07-25-04/checkpoint_000100/checkpoint-100\n",
      "episode return: 1.0\n",
      "Test Finished!\n"
     ]
    }
   ],
   "source": [
    "episode_total_assets = DRL_prediction(\n",
    "        model_name,\n",
    "        test_env_config,\n",
    "        env,\n",
    "        best_config,\n",
    "        best_checkpoint,\n",
    "        env_name_test='StockTrading_test_env')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uRgs0r2Udbvn"
   },
   "outputs": [],
   "source": [
    "print('The test sharpe ratio is: ',calculate_sharpe(episode_total_assets))\n",
    "df_account_test = pd.DataFrame(data=episode_total_assets,columns=['account_value'])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "FinRL Raytune for Hyperparameter Optimization.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
